/-
Copyright (c) 2021 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module

prelude
public import Lean.Meta.Match.MatcherApp.Basic
public import Lean.Meta.Tactic.Apply
public import Lean.Meta.Tactic.Generalize
public import Lean.Meta.Tactic.Simp.Types
import Lean.Meta.Tactic.Simp.Main
import Lean.Meta.Tactic.SplitIf

public section

namespace Lean.Meta
namespace Split

/--
Builds the `Simp.Context` used by this file to simplify `match`-expressions: an empty set of simp
theorems, the congruence theorems returned by `getSimpCongrTheorems`, and `Simp.neutralConfig`
overridden with `dsimp := false`, `etaStruct := .none`, and `letToHave := true`.
-/
def getSimpMatchContext : MetaM Simp.Context := do
  let congrTheorems ← getSimpCongrTheorems
  Simp.mkContext
    (config        := { Simp.neutralConfig with dsimp := false, etaStruct := .none, letToHave := true })
    (simpTheorems  := {})
    (congrTheorems := congrTheorems)

/--
Simplifies every matcher application occurring in `e`, using the context produced by
`getSimpMatchContext` and the discharger produced by `SplitIf.mkDischarge?`.
Only the `Simp.Result` is returned; the statistics computed by `Simp.main` are discarded.
-/
def simpMatch (e : Expr) : MetaM Simp.Result := do
  let discharge? ← SplitIf.mkDischarge?
  let ctx ← getSimpMatchContext
  let (result, _) ← Simp.main e ctx (methods := { pre, discharge? })
  return result
where
  /-- Pre-simplification step: acts on matcher applications only and leaves everything else alone. -/
  pre (e : Expr) : SimpM Simp.Step := do
    if !(← isMatcherApp e) then
      return Simp.Step.continue
    let matcherDeclName := e.getAppFn.constName!
    -- First try to reduce matcher
    if let some e' ← reduceRecMatcher? e then
      return Simp.Step.done { expr := e' }
    -- Reduction did not apply, so defer to `Simp.simpMatchCore`.
    Simp.simpMatchCore matcherDeclName e

/--
Runs `simpMatch` on the (metavariable-instantiated) target of `mvarId` and applies the result to
the goal with `applySimpResultToTarget`, returning the resulting goal.
-/
def simpMatchTarget (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do
  let target ← instantiateMVars (← mvarId.getType)
  applySimpResultToTarget mvarId target (← simpMatch target)

/--
Simplifies `e` by rewriting applications of the matcher `matchDeclName`: each application is first
reduced with `reduceRecMatcher?`; if that does not apply, it is rewritten with the theorem
`matchEqDeclName`. Only the `Simp.Result` is returned; the statistics of `Simp.main` are discarded.
-/
private def simpMatchCore (matchDeclName : Name) (matchEqDeclName : Name) (e : Expr) : MetaM Simp.Result := do
  let discharge? ← SplitIf.mkDischarge?
  let ctx ← getSimpMatchContext
  let (result, _) ← Simp.main e ctx (methods := { pre, discharge? })
  return result
where
  /-- Pre-simplification step restricted to applications of `matchDeclName`. -/
  pre (e : Expr) : SimpM Simp.Step := do
    -- Anything that is not an application of the matcher is left to the default traversal.
    unless e.isAppOf matchDeclName do
      return .continue
    -- First try to reduce matcher
    if let some e' ← reduceRecMatcher? e then
      return .done { expr := e' }
    -- Try lemma
    let simpTheorem := {
      origin := .decl matchEqDeclName
      proof := mkConst matchEqDeclName
      rfl := (← isRflTheorem matchEqDeclName)
    }
    match (← withReducible <| Simp.tryTheorem? e simpTheorem) with
    | some r => return .done r
    | none   => return .continue

/--
Rewrites the target of `mvarId` with `simpMatchCore matchDeclName matchEqDeclName` and replaces the
goal accordingly: via `replaceTargetEq` when the simplifier produced a proof, and via
`replaceTargetDefEq` otherwise.
-/
private def simpMatchTargetCore (mvarId : MVarId) (matchDeclName : Name) (matchEqDeclName : Name) : MetaM MVarId :=
  mvarId.withContext do
    let target ← instantiateMVars (← mvarId.getType)
    let r ← simpMatchCore matchDeclName matchEqDeclName target
    if let some proof := r.proof? then
      mvarId.replaceTargetEq r.expr proof
    else
      mvarId.replaceTargetDefEq r.expr

/--
For every index `i < lhs.size`, introduces a local hypothesis (with a fresh user name based on
`heq`) whose type is `mkEqHEq lhs[i] rhs[i]`, and then runs `k hs rfls`, where `hs` are the
introduced hypotheses and `rfls` are the matching reflexivity proofs for `lhs[i]`
(`mkEqRefl` when the hypothesis is an `Eq`, `mkHEqRefl` otherwise).
-/
private partial def withEqs (lhs rhs : Array Expr) (k : Array Expr → Array Expr → MetaM α) : MetaM α := do
  go 0 #[] #[]
where
  go (idx : Nat) (eqs : Array Expr) (proofs : Array Expr) : MetaM α := do
    if idx ≥ lhs.size then
      -- All pairs have been processed: hand the accumulated hypotheses and proofs to `k`.
      k eqs proofs
    else
      let a := lhs[idx]!
      let eqName ← mkFreshUserName `heq
      let eqType ← mkEqHEq a rhs[idx]!
      withLocalDeclD eqName eqType fun h => do
        -- The proof term must agree with the kind of equality that `mkEqHEq` chose.
        let proof ← if (← inferType h).isEq then mkEqRefl a else mkHEqRefl a
        go (idx+1) (eqs.push h) (proofs.push proof)

/--
Internal exception for discriminant generalization failures due to type errors.

It is registered under the name `discrGeneralizationFailure`, thrown by `generalizeMatchDiscrs`
when the generalized target does not type check, and recognized with `isDiscrGenException`.
-/
builtin_initialize discrGenExId : InternalExceptionId ←
  registerInternalExceptionId `discrGeneralizationFailure

/--
Returns `true` iff `ex` is the internal exception registered as `discrGenExId`, i.e. the one
thrown when generalizing `match` discriminants yields a target that does not type check.
-/
def isDiscrGenException (ex : Exception) : Bool :=
  match ex with
  -- `Exception.internal` has a second field (`extra`) besides the id; only the id identifies
  -- the exception, so the payload is ignored with `_`.
  | .internal id _ => id == discrGenExId
  | _ => false

/--
Throws an error whose message is `msg` followed by a note explaining that the failure usually
indicates that a foreign metaprogram is misusing the internal functions of `split`.
-/
private def throwInternalMisuseError [Monad m] [MonadError m] (msg : MessageData) : m α :=
  let misuseNote : MessageData :=
    .note m!"This error typically occurs when the `split` tactic's internal functions have been used in a new metaprogram"
  throwError msg ++ misuseNote

/--
Split works best if all discriminants are already free variables. If they are not, it will generalize
them, but that may fail if the motive is dependent. So to avoid that, we first generalize all
non-FVar discriminants that are propositions; because of proof irrelevance, that's much simpler.

Returns the updated discriminants (each generalized proof replaced by the free variable introduced
for it, every other discriminant unchanged and in the original order) together with the new goal.
-/
private def generalizeMatchPropDiscrs (mvarId : MVarId) (discrs : Array Expr) : MetaM (Array Expr × MVarId) := mvarId.withContext do
  let mut goal := mvarId
  let mut discrsNew := #[]
  for discr in discrs do
    -- `<&&>` short-circuits, so `isProof` is only consulted for discriminants that are not
    -- already free variables.
    if (← pure (!discr.isFVar) <&&> isProof discr) then
      let (fvarIds, goalNew) ← goal.generalize #[{ expr := discr }]
      goal := goalNew
      discrsNew := discrsNew.push (mkFVar fvarIds[0]!)
    else
      discrsNew := discrsNew.push discr
  return (discrsNew, goal)

/--
  This method makes sure each discriminant is a free variable.
  Return the tuple `(discrsNew, discrEqs, mvarId)`. `discrsNew` is an array representing the new discriminants, `discrEqs` is an array of auxiliary equality hypotheses
  that connect the new discriminants to the original terms they represent.
  Remark: `discrEqs.size ≤ discrsNew.size`

  Remark:
    We should only generalize `discrs` occurrences as `match`-expression discriminants.
    For example, given the following goal.
    ```
    x : Nat
    ⊢ (match g x with
        | 0 => 1
        | Nat.succ y => g x) =
        2 * x + 1
    ```
    we should not generalize the `g x` in the rhs of the second alternative, and the two resulting goals
    for the `split` tactic should be
    ```
    case h_1
    x x✝ : Nat
    h✝ : g x = 0
    ⊢ 1 = 2 * x + 1

    case h_2
    x x✝ y✝ : Nat
    h✝ : g x = Nat.succ y✝
    ⊢ g x = 2 * x + 1
    ```

  If the generalized target does not type check, the internal exception `discrGenExId` is thrown
  (see `isDiscrGenException`).
-/
private partial def generalizeMatchDiscrs (mvarId : MVarId) (matcherDeclName : Name) (motiveType : Expr) (discrs : Array Expr) : MetaM (Array FVarId × Array FVarId × MVarId) := mvarId.withContext do
  -- Fast path: nothing to generalize, so no auxiliary equalities are produced and the goal is unchanged.
  if discrs.all (·.isFVar) then
    trace[split.debug] "no need to generalize discriminants, all are fvars"
    return (discrs.map (·.fvarId!), #[], mvarId)
  let some matcherInfo ← getMatcherInfo? matcherDeclName | unreachable!
  let numDiscrEqs := matcherInfo.getNumDiscrEqs -- Number of `h : discr = pattern` equations
  -- Build the generalized target `∀ discrVars eqs, targetNew` together with the reflexivity proofs
  -- that are later used to instantiate `eqs` at the original discriminants.
  let (targetNew, rfls) ←
    forallTelescope motiveType fun discrVars _ =>
    withEqs discrs discrVars fun eqs rfls => do
      -- Records whether at least one matcher application with the expected discriminants was rewritten.
      let foundRef ← IO.mkRef false
      let rec mkNewTarget (e : Expr) : MetaM Expr := do
        let pre (e : Expr) : MetaM TransformStep := do
          -- Only full applications of the matcher being split are candidates.
          if !e.isAppOfArity matcherDeclName matcherInfo.arity then
            return .continue
          let some matcherApp ← matchMatcherApp? e | return .continue
          -- The application is rewritten only if its discriminants coincide (`==`) with `discrs`.
          for matcherDiscr in matcherApp.discrs, discr in discrs do
            unless matcherDiscr == discr do
              trace[split.debug] "discr mismatch {matcherDiscr} != {discr}"
              return .continue
          let matcherApp := { matcherApp with discrs := discrVars }
          foundRef.set true
          let mut altsNew := #[]
          for h : i in *...matcherApp.alts.size do
            let alt := matcherApp.alts[i]
            let altNumParams := matcherApp.altNumParams[i]!
            let altNew ← lambdaTelescope alt fun xs body => do
              if xs.size < altNumParams || xs.size < numDiscrEqs then
                throwError m!"Internal error in `split` tactic: Encountered an unexpected `match` expression alternative"
                  ++ .note m!"This error typically occurs when the `match` expression has been constructed using metaprogramming."
              -- Rewrite the alternative's body recursively, re-abstracting any binders beyond the
              -- alternative's own `altNumParams` parameters.
              let body ← mkLambdaFVars xs[altNumParams...*] (← mkNewTarget body)
              -- `ys`: the alternative's parameters that precede its discriminant equations.
              let ys  := xs[*...(altNumParams - numDiscrEqs)]
              if numDiscrEqs == 0 then
                mkLambdaFVars ys body
              else
                -- `altEqs`: the last `numDiscrEqs` parameters of the alternative.
                let altEqs := xs[(altNumParams - numDiscrEqs)...altNumParams]
                withNewAltEqs matcherInfo eqs altEqs fun altEqsNew subst => do
                  let body := body.replaceFVars altEqs subst
                  mkLambdaFVars (ys++altEqsNew) body
            altsNew := altsNew.push altNew
          return .done { matcherApp with alts := altsNew }.toExpr
        transform (← instantiateMVars e) pre
      let targetNew ← mkNewTarget (← mvarId.getType)
      unless (← foundRef.get) do
        throwInternalMisuseError m!"Internal error in `split` tactic: Failed to find match-expression discriminants"
      let targetNew ← mkForallFVars (discrVars ++ eqs) targetNew
      -- A type-incorrect generalized target is reported with the dedicated internal exception so
      -- that callers can detect it with `isDiscrGenException`.
      try check targetNew
      catch e =>
        trace[split.debug] "targetNew not type correct:{indentExpr targetNew}\n{e.toMessageData}"
        throw <| Exception.internal discrGenExId
      return (targetNew, rfls)
    let mvarNew ← mkFreshExprSyntheticOpaqueMVar targetNew (← mvarId.getTag)
    trace[split.debug] "targetNew:\n{mvarNew.mvarId!}"
    -- Close the original goal by instantiating the generalized one at the original discriminants
    -- and the reflexivity proofs.
    mvarId.assign (mkAppN (mkAppN mvarNew discrs) rfls)
    -- Introduce the new discriminant variables first, then the auxiliary equalities (one per discriminant).
    let (discrs', mvarId') ← mvarNew.mvarId!.introNP discrs.size
    let (discrEqs, mvarId') ← mvarId'.introNP discrs.size
    return (discrs', discrEqs, mvarId')
where
  /--
    - `eqs` are free variables `h_eq : discr = discrVar`. `eqs.size == discrs.size`
    - `altEqs` are free variables of the form `h_altEq : discr = pattern`. `altEqs.size = numDiscrEqs ≤ discrs.size`
    This method executes `k altEqsNew subst` where
    - `altEqsNew` are fresh free variables of the form `h_altEqNew : discrVar = pattern`
    - `subst` are terms of the form `h_eq.trans h_altEqNew : discr = pattern`. We use `subst` later to replace occurrences of `h_altEq` with `h_eq.trans h_altEqNew`.
   -/
  withNewAltEqs (matcherInfo : MatcherInfo) (eqs : Array Expr) (altEqs : Array Expr) (k : Array Expr → Array Expr → MetaM Expr) : MetaM Expr := do
    -- Keep only the equalities of discriminants whose `discrInfos` entry has a hypothesis name.
    let eqs' := (eqs.zip matcherInfo.discrInfos).filterMap fun (eq, info) => if info.hName?.isNone then none else some eq
    -- `eqs'.size == altEqs.size ≤ eqs.size`
    let rec go (i : Nat) (altEqsNew : Array Expr) (subst : Array Expr) : MetaM Expr := do
      if i < altEqs.size then
        let altEqDecl ← getFVarLocalDecl altEqs[i]!
        let eq := eqs'[i]!
        let eqType ← inferType eq
        let altEqType := altEqDecl.type
        -- Homogeneous case: both hypotheses are `Eq`.
        match eqType.eq?, altEqType.eq? with
        | some (_, _, discrVar), some (_, _ /- discr -/, pattern) =>
          withLocalDeclD altEqDecl.userName (← mkEq discrVar pattern) fun altEqNew => do
            go (i+1) (altEqsNew.push altEqNew) (subst.push (← mkEqTrans eq altEqNew))
        | _, _ =>
        -- Heterogeneous case: both hypotheses are `HEq`.
        match eqType.heq?, altEqType.heq? with
        | some (_, _, _, discrVar), some (_, _ /- discr -/, _, pattern) =>
          withLocalDeclD altEqDecl.userName (← mkHEq discrVar pattern) fun altEqNew => do
            go (i+1) (altEqsNew.push altEqNew) (subst.push (← mkHEqTrans eq altEqNew))
        | _, _ =>
          throwInternalMisuseError m!"Internal error in `split` tactic: Encountered unexpected auxiliary equalities created to generalize `match`-expression discriminant"
      else
        k altEqsNew subst
    go 0 #[] #[]

/--
Tries to eliminate each hypothesis in `discrEqs` from the goal by substitution, threading the
free-variable substitution `fvarSubst` through the process. A hypothesis is processed only if its
image under the current substitution is still a free variable; it is then passed through
`heqToEq` and `substCore?` is attempted with `symm := false` and, if that fails, with
`symm := true`. When neither succeeds, the goal produced by `heqToEq` is kept.
-/
private def substDiscrEqs (mvarId : MVarId) (fvarSubst : FVarSubst) (discrEqs : Array FVarId) : MetaM MVarId := mvarId.withContext do
  let mut goal := mvarId
  let mut subst := fvarSubst
  for discrEq in discrEqs do
    -- Skip entries whose image under the current substitution is not a free variable.
    let .fvar h := subst.apply (mkFVar discrEq) | continue
    let (hEq, goalEq) ← heqToEq goal h
    let result? ←
      match (← substCore? goalEq hEq (symm := false) subst) with
      | some res => pure (some res)
      | none     => substCore? goalEq hEq (symm := true) subst
    match result? with
    | some (substNew, goalNew) =>
      goal := goalNew
      subst := substNew
    | none =>
      goal := goalEq
  return goal

/--
Applies the splitter theorem associated with the matcher `matcherDeclName` (instantiated with the
universe levels `us` and the parameters `params`) to the goal `mvarId`, after generalizing the
discriminants `discrs`.

Returns the resulting goals in alternative order, omitting every alternative whose goal was closed
by `Cases.unifyEqs?`.

Throws an error if `matcherDeclName` is not a matcher, or if the splitter can only eliminate into
`Prop` while the goal lives in another universe. Discriminant generalization failures surface as
the internal exception `discrGenExId`.
-/
def applyMatchSplitter (mvarId : MVarId) (matcherDeclName : Name) (us : Array Level) (params : Array Expr) (discrs : Array Expr) : MetaM (List MVarId) := do
  let some info ← getMatcherInfo? matcherDeclName | throwInternalMisuseError m!"Internal error in `split` tactic: `{matcherDeclName}` is not an auxiliary declaration used to encode `match`-expressions"
  let matchEqns ← Match.getEquationsFor matcherDeclName
  -- splitterPre does not have the correct universe elimination level, but this is fine, we only use it to compute the `motiveType`,
  -- and we only care about the `motiveType` arguments, and not the resulting `Sort u`.
  let splitterPre := mkAppN (mkConst matchEqns.splitterName us.toList) params
  let motiveType := (← whnfForall (← inferType splitterPre)).bindingDomain!
  trace[split.debug] "applyMatchSplitter\n{mvarId}"
  -- Step 1: generalize discriminants that are proofs.
  let (discrs, mvarId) ← generalizeMatchPropDiscrs mvarId discrs
  trace[split.debug] "after generalizeMatchPropDiscrs\n{mvarId}"
  -- Step 2: make every remaining discriminant a free variable, recording the connecting equalities.
  let (discrFVarIds, discrEqs, mvarId) ← generalizeMatchDiscrs mvarId matcherDeclName motiveType discrs
  trace[split.debug] "after generalizeMatchDiscrs\n{mvarId}"
  -- Step 3: generalize the discriminant variables as targets and reintroduce them.
  let mvarId ← generalizeTargetsEq mvarId motiveType (discrFVarIds.map mkFVar)
  mvarId.withContext do trace[split.debug] "discrEqs after generalizeTargetsEq: {discrEqs.map mkFVar}"
  trace[split.debug] "after generalize\n{mvarId}"
  let numEqs := discrs.size
  let (discrFVarIdsNew, mvarId) ← mvarId.introN discrs.size
  trace[split.debug] "after introN\n{mvarId}"
  let discrsNew := discrFVarIdsNew.map mkFVar
  let mvarType ← mvarId.getType
  -- Step 4: fix the universe level the splitter eliminates into, based on the goal's sort.
  let elimUniv ← mvarId.withContext <| getLevel mvarType
  let us ← if let some uElimPos := info.uElimPos? then
    pure <| us.set! uElimPos elimUniv
  else
    -- No elimination universe parameter: the splitter is usable only for goals in `Prop`.
    unless elimUniv.isZero do
      throwError "`split` tactic failed to split a match-expression: The splitter auxiliary theorem \
        `{.ofConstName matchEqns.splitterName}` can only eliminate into `Prop`"
    pure us
  let splitter := mkAppN (mkConst matchEqns.splitterName us.toList) params
  mvarId.withContext do
    -- Step 5: apply the splitter with the goal abstracted over the new discriminants as motive.
    let motive ← mkLambdaFVars discrsNew mvarType
    let splitter := mkAppN (mkApp splitter motive) discrsNew
    check splitter
    trace[split.debug] "after check splitter"
    let mvarIds ← mvarId.applyN splitter matchEqns.size
    -- Step 6: post-process each alternative's goal; `i` is the alternative index and the list is
    -- accumulated in reverse.
    let (_, mvarIds) ← mvarIds.foldlM (init := (0, [])) fun (i, mvarIds) mvarId => do
      let altInfo := matchEqns.splitterMatchInfo.altInfos[i]!
      let mvarId ←
        if altInfo.hasUnitThunk then
          trace[split.debug] "introducing unit param for alt {(i : Nat)}"
          let (unitFvarId, mvarId) ← mvarId.intro1
          mvarId.tryClear unitFvarId
        else
          let (_, mvarId) ← mvarId.introN (altInfo.numFields + altInfo.numOverlaps)
          pure mvarId
      trace[split.debug] "before unifyEqs\n{mvarId}"
      match (← Cases.unifyEqs? (info.getNumDiscrEqs + numEqs) mvarId {}) with
      | none   => return (i+1, mvarIds) -- case was solved
      | some (mvarId, fvarSubst) =>
        trace[split.debug] "after unifyEqs\n{mvarId}"
        -- Eliminate the auxiliary equalities introduced by `generalizeMatchDiscrs`.
        let mvarId ← substDiscrEqs mvarId fvarSubst discrEqs
        return (i+1, mvarId::mvarIds)
    return mvarIds.reverse

/--
Error message reported when the discriminants of the `match`-expression `e` could not be
generalized because the resulting expression was not type correct.
-/
def mkDiscrGenErrorMsg (e : Expr) : MessageData :=
  m!"`split` tactic failed to generalize discriminant(s) at{indentExpr e}\n\
    resulting expression was not type correct\n\
    possible solution: generalize discriminant(s) manually before using `split`"

/-- Throws an error whose message is `mkDiscrGenErrorMsg e`. -/
def throwDiscrGenError (e : Expr) : MetaM α :=
  throwError mkDiscrGenErrorMsg e

/--
Splits the goal `mvarId` on the matcher application `e`: applies the matcher's splitter with
`applyMatchSplitter` and then simplifies the `match` in each resulting goal with the equation
theorem selected by the goal's position in the returned list.
Throws an error if `e` is not a matcher application.
-/
def splitMatch (mvarId : MVarId) (e : Expr) : MetaM (List MVarId) := mvarId.withContext do
  let some app ← matchMatcherApp? e | throwInternalMisuseError m!"Internal error in `split` tactic: Match application expected{indentExpr e}"
  let matchEqns ← Match.getEquationsFor app.matcherName
  let goals ← applyMatchSplitter mvarId app.matcherName app.matcherLevels app.params app.discrs
  -- NOTE(review): `applyMatchSplitter` omits goals that were closed by `Cases.unifyEqs?`, so the
  -- position of a goal in `goals` need not be the index of its alternative — confirm that
  -- `eqnNames[idx]!` is the intended equation in that situation.
  let mut idx := 0
  let mut simplified : Array MVarId := #[]
  for goal in goals do
    let goal ← simpMatchTargetCore goal app.matcherName matchEqns.eqnNames[idx]!
    simplified := simplified.push goal
    idx := idx + 1
  return simplified.toList

end Split

open Split

/--
Splits an `if-then-else` or `match`-expression in the goal target.
If `useNewSemantics` is `true`, the flag `backward.split` is ignored. Recall this flag only affects the split of `if-then-else` expressions.

If `splitIte` is `false`, only `match`-expressions are considered.
Returns `none` when no (remaining) candidate is found; the whole attempt runs under
`commitWhenSome?`. A `match`-expression on which `splitMatch` fails is recorded as a bad case and
the search continues with the next candidate.
-/
partial def splitTarget? (mvarId : MVarId) (splitIte := true) (useNewSemantics := false) : MetaM (Option (List MVarId)) := commitWhenSome? do mvarId.withContext do
  let target ← instantiateMVars (← mvarId.getType)
  -- `badCases` collects the candidates on which `splitMatch` already failed, so that
  -- `findSplit?` does not return them again.
  let rec go (badCases : ExprSet) : MetaM (Option (List MVarId)) := do
    if let some e ← findSplit? target (if splitIte then .both else .match) badCases then
      if e.isIte || e.isDIte then
        return (← splitIfTarget? mvarId (useNewSemantics := useNewSemantics)).map fun (s₁, s₂) => [s₁.mvarId, s₂.mvarId]
      else
        -- NOTE(review): the retry below reuses `mvarId` after a failed `splitMatch`; confirm that
        -- a failed attempt cannot leave `mvarId` assigned.
        try
          splitMatch mvarId e
        catch ex =>
          -- Failures are only traced (under `split.failure`); the candidate is then skipped.
          if isDiscrGenException ex then
            trace[split.failure] mkDiscrGenErrorMsg e
          else
            trace[split.failure] "`split` tactic failed at{indentExpr e}\n{ex.toMessageData}"
          go (badCases.insert e)
    else
      trace[split.debug] "did not find term to split\n{MessageData.ofGoal mvarId}"
      return none
  go {}

/--
Splits an `if-then-else` or `match`-expression occurring in the type of the hypothesis `fvarId`
of the goal `mvarId`.

Returns `none` if `findSplit?` finds no candidate in the hypothesis type; the whole attempt runs
under `commitWhenSome?`. For a `match`-expression, the hypothesis is reverted, the target is split
with `splitMatch`, and the reverted hypotheses are reintroduced. If that fails because the
discriminants could not be generalized, and the hypothesis is a let-declaration, occurs in the
target, or has forward dependencies, a copy of it is asserted and the split is retried; otherwise
(or if the retry fails for the same reason) the error of `throwDiscrGenError` is thrown.
-/
def splitLocalDecl? (mvarId : MVarId) (fvarId : FVarId) : MetaM (Option (List MVarId)) := commitWhenSome? do
  mvarId.withContext do
    if let some e ← findSplit? (← instantiateMVars (← inferType (mkFVar fvarId))) then
      if e.isIte || e.isDIte then
        return (← splitIfLocalDecl? mvarId fvarId).map fun (mvarId₁, mvarId₂) => [mvarId₁, mvarId₂]
      else
        -- First attempt: revert the hypothesis and split the resulting target. The attempt is
        -- wrapped in `commitWhenSome?`, and a discriminant generalization failure is turned into `none`.
        let result? ← commitWhenSome? do try
          let (fvarIds, mvarId) ← mvarId.revert #[fvarId]
          let num := fvarIds.size
          let mvarIds ← splitMatch mvarId e
          -- Reintroduce everything that `revert` moved into the target.
          let mvarIds ← mvarIds.mapM fun mvarId => return (← mvarId.introNP num).2
          return some mvarIds
        catch ex =>
          if isDiscrGenException ex then
            return none
          else
            throw ex
        if result?.isSome then
          return result?
        -- Generalization failed, if `fvarId` is a let-decl or has forward dependencies, we try to `assert` a copy and try again
        let localDecl ← fvarId.getDecl
        if (← pure localDecl.isLet <||> exprDependsOn (← mvarId.getType) fvarId <||> fvarId.hasForwardDeps) then
          try
            let mvarId ← mvarId.assert localDecl.userName localDecl.type localDecl.toExpr
            let mvarIds ← splitMatch mvarId e
            -- Introduce the asserted copy in each resulting goal.
            let mvarIds ← mvarIds.mapM fun mvarId => return (← mvarId.intro1P).2
            return some mvarIds
          catch ex =>
            -- Report a generalization failure with the user-facing message instead of the
            -- internal exception.
            if isDiscrGenException ex then
              throwDiscrGenError e
            else
              throw ex
        throwDiscrGenError e
    else
      return none

-- Trace classes used in this file: `split.debug` for step-by-step traces of the goals being
-- manipulated, and `split.failure` for reports about terms that could not be split.
builtin_initialize
  registerTraceClass `split.debug
  registerTraceClass `split.failure

end Lean.Meta
